{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\s1612415\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x20d7eab4810>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed PyTorch's global random number generator with a fixed value.\n",
    "torch.manual_seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training split of FashionMNIST, stored under ./data (downloaded when missing), converted with ToTensor.\n",
    "training_data = datasets.FashionMNIST(root=\"data\", train=True, download=True, transform=ToTensor())\n",
    "\n",
    "# Matching test split, same storage root and same conversion.\n",
    "test_data = datasets.FashionMNIST(root=\"data\", train=False, download=True, transform=ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])\n",
      "Shape of y: torch.Size([64]) torch.int64\n"
     ]
    }
   ],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# First loader pair (used with `model`).\n",
    "train_dataloader, test_dataloader = (\n",
    "    DataLoader(training_data, batch_size=batch_size),\n",
    "    DataLoader(test_data, batch_size=batch_size),\n",
    ")\n",
    "\n",
    "# Second, identically configured pair (used with `model2`).\n",
    "train_dataloader2, test_dataloader2 = (\n",
    "    DataLoader(training_data, batch_size=batch_size),\n",
    "    DataLoader(test_data, batch_size=batch_size),\n",
    ")\n",
    "\n",
    "# Look at the first test batch only, to show the tensor shapes and the label dtype.\n",
    "X, y = next(iter(test_dataloader))\n",
    "print(f\"Shape of X [N, C, H, W]: {X.shape}\")\n",
    "print(f\"Shape of y: {y.shape} {y.dtype}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cpu device\n",
      "NeuralNetwork(\n",
      "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=784, out_features=512, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=512, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Choose the compute device: CUDA if present, otherwise Apple MPS, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using {device} device\")\n",
    "\n",
    "# Define model\n",
    "class NeuralNetwork(nn.Module):\n",
    "    \"\"\"Fully connected classifier: flatten, then 784 -> 512 -> 512 -> 10 with ReLU between the linear layers.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        hidden = 512\n",
    "        self.flatten = nn.Flatten()\n",
    "        # Layers are listed input-to-output; the final layer emits 10 raw scores.\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(28 * 28, hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden, hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden, 10),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Flatten each sample and return the stack's output (logits).\"\"\"\n",
    "        return self.linear_relu_stack(self.flatten(x))\n",
    "\n",
    "# Instantiate the dense network and move it to the selected device.\n",
    "model = NeuralNetwork().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LowRankNeuralNetwork(\n",
      "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): LowRankLinearLayer()\n",
      "    (1): ReLU()\n",
      "    (2): LowRankLinearLayer()\n",
      "    (3): ReLU()\n",
      "    (4): LowRankLinearLayer()\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class LowRankLinearLayer(nn.Module):\n",
    "    \"\"\"Linear layer whose weight is stored as a truncated SVD: W ~ U @ diag(S) @ Vh.\n",
    "\n",
    "    Args:\n",
    "        initial_weight: 2-D weight of shape (out_features, in_features) to factorize.\n",
    "        initial_bias: 1-D bias of shape (out_features,); it is copied, not shared.\n",
    "        rank: number of leading singular triplets to keep. Values larger than\n",
    "            min(out_features, in_features) keep the whole decomposition.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if rank is smaller than 1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, initial_weight, initial_bias, rank):\n",
    "        super().__init__()\n",
    "        if rank < 1:\n",
    "            raise ValueError(f'rank must be at least 1, got {rank}')\n",
    "        # Factorize a detached view so no autograd graph back to the source weight is recorded.\n",
    "        U, S, Vh = torch.linalg.svd(initial_weight.detach(), full_matrices=False)\n",
    "        # clone() gives each truncated factor its own storage instead of a view into the full factor.\n",
    "        self.U = nn.Parameter(U[:, :rank].clone())\n",
    "        self.S = nn.Parameter(S[:rank].clone())\n",
    "        self.Vh = nn.Parameter(Vh[:rank, :].clone())\n",
    "        # nn.Parameter(tensor) shares the tensor's storage; copy the bias so that training this\n",
    "        # layer cannot change the bias of the layer it was initialised from (and vice versa).\n",
    "        self.bias = nn.Parameter(initial_bias.detach().clone())\n",
    "\n",
    "    def extra_repr(self):\n",
    "        \"\"\"Show the layer dimensions and kept rank in print(module).\"\"\"\n",
    "        return f'in_features={self.Vh.shape[1]}, out_features={self.U.shape[0]}, rank={self.S.shape[0]}'\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return x @ (U diag(S) Vh)^T + bias for x of shape (..., in_features).\"\"\"\n",
    "        # Apply the factors one after another rather than rebuilding the dense\n",
    "        # (out_features, in_features) matrix on every call.\n",
    "        return ((x @ self.Vh.T) * self.S) @ self.U.T + self.bias\n",
    "\n",
    "\n",
    "    \n",
    "class LowRankNeuralNetwork(nn.Module):\n",
    "    \"\"\"Low-rank counterpart of a dense classifier: each nn.Linear becomes a LowRankLinearLayer.\n",
    "\n",
    "    Args:\n",
    "        prec: fraction of each layer's maximum rank, min(out_features, in_features), to keep.\n",
    "        base_model: dense network exposing `linear_relu_stack` whose weights and biases are\n",
    "            factorized. Defaults to the notebook-level `model`, which is what the original\n",
    "            implementation always used.\n",
    "\n",
    "    Raises:\n",
    "        TypeError: if the stack contains a layer that is neither nn.Linear nor nn.ReLU.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, prec, base_model=None):\n",
    "        super().__init__()\n",
    "        if base_model is None:\n",
    "            # Backward-compatible default: factorize the notebook's global dense model.\n",
    "            base_model = model\n",
    "        self.flatten = nn.Flatten()\n",
    "        layers = []\n",
    "        for layer in base_model.linear_relu_stack:\n",
    "            if isinstance(layer, nn.Linear):\n",
    "                # Rank comes from the layer's own shape; for the 784-512-512-10 network this\n",
    "                # yields the same int(prec * 512), int(prec * 512), int(prec * 10) as before.\n",
    "                rank = int(prec * min(layer.weight.shape))\n",
    "                layers.append(LowRankLinearLayer(layer.weight, layer.bias, rank))\n",
    "            elif isinstance(layer, nn.ReLU):\n",
    "                layers.append(nn.ReLU())\n",
    "            else:\n",
    "                raise TypeError(f'unsupported layer in linear_relu_stack: {type(layer).__name__}')\n",
    "        self.linear_relu_stack = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Flatten each sample and return the low-rank stack's output (logits).\"\"\"\n",
    "        x = self.flatten(x)\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        return logits\n",
    "\n",
    "# Build the low-rank model with prec=0.6 and move it to the selected device.\n",
    "model2 = LowRankNeuralNetwork(0.6).to(device)\n",
    "print(model2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(model):\n",
    "    \"\"\"Print the element count of every trainable parameter and return their sum.\"\"\"\n",
    "    # Parameters with requires_grad=False are left out of both the listing and the total.\n",
    "    sizes = {name: p.numel() for name, p in model.named_parameters() if p.requires_grad}\n",
    "    for name, n_elems in sizes.items():\n",
    "        print(f'{name}: {n_elems}')\n",
    "    total_params = sum(sizes.values())\n",
    "    print(f\"Total Trainable Params: {total_params}\")\n",
    "    return total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear_relu_stack.0.weight: 401408\n",
      "linear_relu_stack.0.bias: 512\n",
      "linear_relu_stack.2.weight: 262144\n",
      "linear_relu_stack.2.bias: 512\n",
      "linear_relu_stack.4.weight: 5120\n",
      "linear_relu_stack.4.bias: 10\n",
      "Total Trainable Params: 669706\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "669706"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Trainable parameter count of the dense model.\n",
    "count_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linear_relu_stack.0.U: 157184\n",
      "linear_relu_stack.0.S: 307\n",
      "linear_relu_stack.0.Vh: 240688\n",
      "linear_relu_stack.0.bias: 512\n",
      "linear_relu_stack.2.U: 157184\n",
      "linear_relu_stack.2.S: 307\n",
      "linear_relu_stack.2.Vh: 157184\n",
      "linear_relu_stack.2.bias: 512\n",
      "linear_relu_stack.4.U: 60\n",
      "linear_relu_stack.4.S: 6\n",
      "linear_relu_stack.4.Vh: 3072\n",
      "linear_relu_stack.4.bias: 10\n",
      "Total Trainable Params: 717026\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "717026"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Trainable parameter count of the low-rank model.\n",
    "# NOTE(review): the saved output reports 717026 here versus 669706 for the dense model,\n",
    "# i.e. with prec=0.6 the factors U, S and Vh together hold more values than the dense weights.\n",
    "count_parameters(model2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each model gets its own cross-entropy loss object and its own SGD optimizer (lr=1e-3 for both).\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "loss_fn2 = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n",
    "optimizer2 = torch.optim.SGD(model2.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataloader, model, loss_fn, optimizer):\n",
    "    \"\"\"Run one pass over `dataloader`, stepping `optimizer` after every batch.\n",
    "\n",
    "    Batches are moved to the notebook-level `device`. The batch loss is printed\n",
    "    for every 100th batch, starting with the first.\n",
    "    \"\"\"\n",
    "    n_samples = len(dataloader.dataset)\n",
    "    model.train()\n",
    "    for step, (inputs, targets) in enumerate(dataloader):\n",
    "        inputs = inputs.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        # Forward pass and loss for this batch.\n",
    "        batch_loss = loss_fn(model(inputs), targets)\n",
    "\n",
    "        # Clear old gradients, backpropagate, apply the update.\n",
    "        optimizer.zero_grad()\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 100 != 0:\n",
    "            continue\n",
    "        seen = step * len(inputs)\n",
    "        print(f\"loss: {batch_loss.item():>7f}  [{seen:>5d}/{n_samples:>5d}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(dataloader, model, loss_fn):\n",
    "    \"\"\"Evaluate `model` on `dataloader`; print accuracy and the loss averaged over batches.\"\"\"\n",
    "    n_samples = len(dataloader.dataset)\n",
    "    n_batches = len(dataloader)\n",
    "    model.eval()\n",
    "    loss_sum = 0\n",
    "    n_correct = 0\n",
    "    # Gradients are not needed for evaluation.\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in dataloader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            scores = model(inputs)\n",
    "            loss_sum += loss_fn(scores, targets).item()\n",
    "            # A sample counts as correct when its highest-scoring class equals the label.\n",
    "            hits = scores.argmax(1) == targets\n",
    "            n_correct += hits.type(torch.float).sum().item()\n",
    "    test_loss = loss_sum / n_batches\n",
    "    accuracy = n_correct / n_samples\n",
    "    print(f\"Test Error: \\n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 2.294026  [    0/60000]\n",
      "loss: 2.285374  [ 6400/60000]\n",
      "loss: 2.261674  [12800/60000]\n",
      "loss: 2.258286  [19200/60000]\n",
      "loss: 2.246893  [25600/60000]\n",
      "loss: 2.204551  [32000/60000]\n",
      "loss: 2.224775  [38400/60000]\n",
      "loss: 2.186471  [44800/60000]\n",
      "loss: 2.188018  [51200/60000]\n",
      "loss: 2.143335  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 34.9%, Avg loss: 2.144163 \n",
      "\n",
      "loss: 2.299525  [    0/60000]\n",
      "loss: 2.295399  [ 6400/60000]\n",
      "loss: 2.289398  [12800/60000]\n",
      "loss: 2.287983  [19200/60000]\n",
      "loss: 2.290339  [25600/60000]\n",
      "loss: 2.271779  [32000/60000]\n",
      "loss: 2.283621  [38400/60000]\n",
      "loss: 2.272868  [44800/60000]\n",
      "loss: 2.274235  [51200/60000]\n",
      "loss: 2.261751  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 24.1%, Avg loss: 2.262018 \n",
      "\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 2.155175  [    0/60000]\n",
      "loss: 2.143347  [ 6400/60000]\n",
      "loss: 2.081628  [12800/60000]\n",
      "loss: 2.104539  [19200/60000]\n",
      "loss: 2.053961  [25600/60000]\n",
      "loss: 1.977515  [32000/60000]\n",
      "loss: 2.022094  [38400/60000]\n",
      "loss: 1.939490  [44800/60000]\n",
      "loss: 1.949365  [51200/60000]\n",
      "loss: 1.860123  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 54.5%, Avg loss: 1.868676 \n",
      "\n",
      "loss: 2.267310  [    0/60000]\n",
      "loss: 2.263679  [ 6400/60000]\n",
      "loss: 2.246724  [12800/60000]\n",
      "loss: 2.252497  [19200/60000]\n",
      "loss: 2.253488  [25600/60000]\n",
      "loss: 2.214979  [32000/60000]\n",
      "loss: 2.251369  [38400/60000]\n",
      "loss: 2.225450  [44800/60000]\n",
      "loss: 2.228664  [51200/60000]\n",
      "loss: 2.212713  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 28.9%, Avg loss: 2.211816 \n",
      "\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 1.897480  [    0/60000]\n",
      "loss: 1.865716  [ 6400/60000]\n",
      "loss: 1.747558  [12800/60000]\n",
      "loss: 1.798576  [19200/60000]\n",
      "loss: 1.684831  [25600/60000]\n",
      "loss: 1.627527  [32000/60000]\n",
      "loss: 1.657716  [38400/60000]\n",
      "loss: 1.567649  [44800/60000]\n",
      "loss: 1.586581  [51200/60000]\n",
      "loss: 1.472576  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 63.1%, Avg loss: 1.498446 \n",
      "\n",
      "loss: 2.223347  [    0/60000]\n",
      "loss: 2.219634  [ 6400/60000]\n",
      "loss: 2.186698  [12800/60000]\n",
      "loss: 2.200322  [19200/60000]\n",
      "loss: 2.202906  [25600/60000]\n",
      "loss: 2.132553  [32000/60000]\n",
      "loss: 2.198699  [38400/60000]\n",
      "loss: 2.152831  [44800/60000]\n",
      "loss: 2.153635  [51200/60000]\n",
      "loss: 2.131352  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 35.8%, Avg loss: 2.130316 \n",
      "\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 1.554919  [    0/60000]\n",
      "loss: 1.523865  [ 6400/60000]\n",
      "loss: 1.375502  [12800/60000]\n",
      "loss: 1.459478  [19200/60000]\n",
      "loss: 1.327151  [25600/60000]\n",
      "loss: 1.328126  [32000/60000]\n",
      "loss: 1.344154  [38400/60000]\n",
      "loss: 1.281542  [44800/60000]\n",
      "loss: 1.307068  [51200/60000]\n",
      "loss: 1.204128  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 63.9%, Avg loss: 1.232651 \n",
      "\n",
      "loss: 2.150195  [    0/60000]\n",
      "loss: 2.144425  [ 6400/60000]\n",
      "loss: 2.083155  [12800/60000]\n",
      "loss: 2.105351  [19200/60000]\n",
      "loss: 2.114027  [25600/60000]\n",
      "loss: 1.999209  [32000/60000]\n",
      "loss: 2.090614  [38400/60000]\n",
      "loss: 2.022357  [44800/60000]\n",
      "loss: 2.006956  [51200/60000]\n",
      "loss: 1.969036  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 42.6%, Avg loss: 1.970297 \n",
      "\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 1.295887  [    0/60000]\n",
      "loss: 1.287649  [ 6400/60000]\n",
      "loss: 1.121291  [12800/60000]\n",
      "loss: 1.243374  [19200/60000]\n",
      "loss: 1.096670  [25600/60000]\n",
      "loss: 1.132355  [32000/60000]\n",
      "loss: 1.155518  [38400/60000]\n",
      "loss: 1.103880  [44800/60000]\n",
      "loss: 1.135899  [51200/60000]\n",
      "loss: 1.048307  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 65.5%, Avg loss: 1.070513 \n",
      "\n",
      "loss: 2.004038  [    0/60000]\n",
      "loss: 1.987362  [ 6400/60000]\n",
      "loss: 1.868093  [12800/60000]\n",
      "loss: 1.900870  [19200/60000]\n",
      "loss: 1.920072  [25600/60000]\n",
      "loss: 1.751031  [32000/60000]\n",
      "loss: 1.856919  [38400/60000]\n",
      "loss: 1.775727  [44800/60000]\n",
      "loss: 1.730846  [51200/60000]\n",
      "loss: 1.651195  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 45.1%, Avg loss: 1.672433 \n",
      "\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "epochs = 5\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "\n",
    "    # Dense model: one training pass, then evaluation on the test loader.\n",
    "    train(train_dataloader, model, loss_fn, optimizer)\n",
    "    test(test_dataloader, model, loss_fn)\n",
    "    \n",
    "    # Low-rank model: same schedule with its own loaders, loss and optimizer.\n",
    "    train(train_dataloader2, model2, loss_fn2, optimizer2)\n",
    "    test(test_dataloader2, model2, loss_fn2)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved PyTorch Model State to model.pth\n",
      "Saved PyTorch Model State to model.pth\n"
     ]
    }
   ],
   "source": [
    "# Save the dense model's state dict.\n",
    "torch.save(model.state_dict(), \"model.pth\")\n",
    "print(\"Saved PyTorch Model State to model.pth\")\n",
    "\n",
    "# Save the low-rank model's state dict; the message names the file that is actually written.\n",
    "torch.save(model2.state_dict(), \"model2.pth\")\n",
    "print(\"Saved PyTorch Model State to model2.pth\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "9b3b15e79d64212a14c381e1bc9a41101994b32312634469deb1a16fd6054240"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
